# Copyright (c) OpenMMLab. All rights reserved.
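"""Chat with a (multimodal) Hugging Face model from the command line.

Supports plain LLM chat, PEFT/LoRA adapters, LLaVA-style image chat
(CLIP visual encoder + projector), MOSS-style plugins, and lagent ReAct
agents. Example invocation (model and adapter paths are illustrative):

    python chat.py internlm/internlm-chat-7b \
        --adapter ./work_dirs/my_adapter --prompt-template internlm_chat
"""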
import argparse
import os
import os.path as osp
import re
import sys

import torch
from huggingface_hub import snapshot_download
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    CLIPImageProcessor,
    CLIPVisionModel,
    GenerationConfig,
)
from transformers.generation.streamers import TextStreamer

from xtuner.dataset.utils import expand2square, load_image
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import (
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
    PROMPT_TEMPLATE,
    SYSTEM_TEMPLATE,
)
from xtuner.utils.device import get_device

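# Map the CLI dtype strings onto torch dtypes; "auto" defers to transformers.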
TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto="auto"
)


def remove_prefix(state_dict, prefix):
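    """Strip ``prefix`` from the keys of ``state_dict`` that start with it."""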
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            new_key = key[len(prefix) :]
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    return new_state_dict


def parse_args():
    parser = argparse.ArgumentParser(description="Chat with a HF model")
    parser.add_argument("model_name_or_path", help="Hugging Face model name or path")
    adapter_group = parser.add_mutually_exclusive_group()
    adapter_group.add_argument("--adapter", default=None, help="adapter name or path")
    adapter_group.add_argument("--llava", default=None, help="llava name or path")
    parser.add_argument(
        "--visual-encoder", default=None, help="visual encoder name or path"
    )
    parser.add_argument("--visual-select-layer", default=-2, help="visual select layer")
    parser.add_argument("--image", default=None, help="image")
    parser.add_argument(
        "--torch-dtype",
        default="fp16",
        choices=TORCH_DTYPE_MAP.keys(),
        help="Override the default `torch.dtype` and load the model under "
        "a specific `dtype`.",
    )
    parser.add_argument(
        "--prompt-template",
        choices=PROMPT_TEMPLATE.keys(),
        default=None,
        help="Specify a prompt template",
    )
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument("--system", default=None, help="Specify the system text")
    system_group.add_argument(
        "--system-template",
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help="Specify a system template",
    )
    parser.add_argument(
        "--bits", type=int, choices=[4, 8, None], default=None, help="LLM bits"
    )
    parser.add_argument("--bot-name", type=str, default="BOT", help="Name for Bot")
    parser.add_argument(
        "--with-plugins",
        nargs="+",
        choices=["calculate", "solve", "search"],
        help="Specify plugins to use",
    )
    parser.add_argument(
        "--no-streamer", action="store_true", help="Whether to disable the streamer"
    )
    parser.add_argument("--lagent", action="store_true", help="Whether to use lagent")
    parser.add_argument(
        "--stop-words", nargs="+", type=str, default=[], help="Stop words"
    )
    parser.add_argument(
        "--offload-folder",
        default=None,
        help="The folder in which to offload the model weights (or where the "
        "model weights are already offloaded).",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=2048,
        help="Maximum number of new tokens allowed in generated text",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="The value used to modulate the next token probabilities.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=40,
        help="The number of highest probability vocabulary tokens to "
        "keep for top-k-filtering.",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=0.75,
        help="If set to float < 1, only the smallest set of most probable "
        "tokens with probabilities that add up to top_p or higher are "
        "kept for generation.",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="The parameter for repetition penalty. 1.0 means no penalty.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducible text generation",
    )
    args = parser.parse_args()
    return args


def get_input():
    """Helper function for getting input from users."""
    sentinel = ""  # ends when this string is seen
    result = None
    while result is None:
        print(
            (
                "\ndouble enter to end input (EXIT: exit chat, "
                "RESET: reset history) >>> "
            ),
            end="",
        )
        try:
            result = "\n".join(iter(input, sentinel))
        except UnicodeDecodeError:
            print("Invalid characters detected. Please enter again.")
    return result


def main():
    args = parse_args()
    torch.manual_seed(args.seed)

    # build llm
    quantization_config = None
    load_in_8bit = False
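    # 4-bit uses bitsandbytes NF4 with double quantization and fp16 compute;
    # 8-bit falls back to LLM.int8().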
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        "quantization_config": quantization_config,
        "load_in_8bit": load_in_8bit,
        "device_map": "auto",
        "offload_folder": args.offload_folder,
        "trust_remote_code": True,
        "torch_dtype": TORCH_DTYPE_MAP[args.torch_dtype],
    }
    if args.lagent:
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import (
            CALL_PROTOCOL_CN,
            FORCE_STOP_PROMPT_CN,
            ReAct,
            ReActProtocol,
        )
        from lagent.llms import HFTransformerCasualLM

        try:
            SERPER_API_KEY = os.environ["SERPER_API_KEY"]
        except KeyError:
            print(
                "Please obtain the `SERPER_API_KEY` from https://serper.dev "
                "and set it using `export SERPER_API_KEY=xxx`."
            )
            sys.exit(1)

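        # lagent builds the HF model itself; `trust_remote_code` is not
        # expected inside `model_kwargs`, so drop it here.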
        model_kwargs.pop("trust_remote_code")
        llm = HFTransformerCasualLM(args.model_name_or_path, model_kwargs=model_kwargs)
        if args.adapter is not None:
            print(f"Loading adapter from {args.adapter}...")
            llm.model = PeftModel.from_pretrained(
                llm.model,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True,
            )
        search_tool = GoogleSearch(api_key=SERPER_API_KEY)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
            protocol=ReActProtocol(
                call_protocol=CALL_PROTOCOL_CN, force_stop=FORCE_STOP_PROMPT_CN
            ),
        )
        while True:
            text = get_input()
            while text.strip() == "RESET":
                print("Log: History responses have been removed!")
                chatbot._session_history = []
                inputs = ""
                text = get_input()
            if text.strip() == "EXIT":
                print("Log: Exit!")
                sys.exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
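        # Plain transformers chat loop (optionally with MOSS-style plugins).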
        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert (
                args.prompt_template == args.system_template == "moss_sft"
            ), "Plugins require `--prompt-template moss_sft --system-template moss_sft`"
            from plugins import plugins_api

            inner_thoughts_open = True
            calculate_open = "calculate" in args.with_plugins
            solve_open = "solve" in args.with_plugins
            search_open = "search" in args.with_plugins
            # pre-import for api and model preparation
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401
        # build llm
        llm = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path, **model_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path, trust_remote_code=True, encode_special_tokens=True
        )
        print(f"Load LLM from {args.model_name_or_path}")
        if args.adapter is not None:
            llm = PeftModel.from_pretrained(
                llm,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True,
            )
            print(f"Load adapter from {args.adapter}")
        if args.llava is not None:
            llava_path = (
                snapshot_download(repo_id=args.llava)
                if not osp.isdir(args.llava)
                else args.llava
            )

            # build visual_encoder
            if "visual_encoder" in os.listdir(llava_path):
                assert args.visual_encoder is None, (
                    "Please don't specify the `--visual-encoder` since passed "
                    "`--llava` contains a visual encoder!"
                )
                visual_encoder_path = osp.join(llava_path, "visual_encoder")
            else:
                assert (
                    args.visual_encoder is not None
                ), "Please specify the `--visual-encoder`!"
                visual_encoder_path = args.visual_encoder
            visual_encoder = CLIPVisionModel.from_pretrained(
                visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype]
            )
            image_processor = CLIPImageProcessor.from_pretrained(visual_encoder_path)
            print(f"Load visual_encoder from {visual_encoder_path}")

            # load adapter
            if "llm_adapter" in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, "llm_adapter")
                llm = PeftModel.from_pretrained(
                    llm,
                    adapter_path,
                    offload_folder=args.offload_folder,
                    trust_remote_code=True,
                )
                print(f"Load LLM adapter from {args.llava}")
            if "visual_encoder_adapter" in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, "visual_encoder_adapter")
                visual_encoder = PeftModel.from_pretrained(
                    visual_encoder, adapter_path, offload_folder=args.offload_folder
                )
                print(f"Load visual_encoder adapter from {args.llava}")

            # build projector
            projector_path = osp.join(llava_path, "projector")
            projector = AutoModel.from_pretrained(
                projector_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
                trust_remote_code=True,
            )
            print(f"Load projector from {args.llava}")

            projector.to(get_device())
            projector.eval()
            visual_encoder.to(get_device())
            visual_encoder.eval()

        llm.eval()

        if args.image is not None:
            assert args.llava is not None, "`--image` requires `--llava`!"
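            # Pad the image to a square with the processor's mean color so
            # CLIP sees it undistorted, then preprocess to pixel values.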
            image = load_image(args.image)
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            image = image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ][0]
            image = image.to(get_device()).unsqueeze(0).to(visual_encoder.dtype)
            visual_outputs = visual_encoder(image, output_hidden_states=True)
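            # Drop the CLS token ([:, 1:]) and project the selected hidden
            # layer's patch features into the LLM embedding space.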
            pixel_values = projector(
                visual_outputs.hidden_states[args.visual_select_layer][:, 1:]
            )

        stop_words = args.stop_words
        sep = ""
        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            stop_words += template.get("STOP_WORDS", [])
            sep = template.get("SEP", "")
        stop_criteria = get_stop_criteria(tokenizer=tokenizer, stop_words=stop_words)

        if args.no_streamer:
            streamer = None
        else:
            streamer = TextStreamer(tokenizer, skip_prompt=True)

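        # temperature == 0 falls back to greedy decoding; otherwise sample
        # with the configured top-k/top-p.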
        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id,
        )

        n_turn = 0
        inputs = ""
        while True:
            text = get_input()
            while text.strip() == "RESET":
                print("Log: History responses have been removed!")
                n_turn = 0
                inputs = ""
                text = get_input()
            if text.strip() == "EXIT":
                print("Log: Exit!")
                sys.exit(0)

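            # Prepend the image token only on the first turn; it stays in the
            # accumulated `inputs` for the rest of the conversation.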
            if args.image is not None and n_turn == 0:
                text = DEFAULT_IMAGE_TOKEN + "\n" + text

            if args.prompt_template:
                prompt_text = ""
                template = PROMPT_TEMPLATE[args.prompt_template]
                if "SYSTEM" in template and n_turn == 0:
                    system_text = None
                    if args.system_template is not None:
                        system_text = SYSTEM_TEMPLATE[args.system_template].format(
                            round=n_turn + 1, bot_name=args.bot_name
                        )
                    elif args.system is not None:
                        system_text = args.system
                    if system_text is not None:
                        prompt_text += template["SYSTEM"].format(
                            system=system_text, round=n_turn + 1, bot_name=args.bot_name
                        )
                prompt_text += template["INSTRUCTION"].format(
                    input=text, round=n_turn + 1, bot_name=args.bot_name
                )
                if args.prompt_template == args.system_template == "moss_sft":
                    # Disable the banners for plugins that are not active.
                    if not inner_thoughts_open:
                        prompt_text = prompt_text.replace(
                            "- Inner thoughts: enabled.", "- Inner thoughts: disabled."
                        )
                    if not calculate_open:
                        prompt_text = prompt_text.replace(
                            "- Calculator: enabled. API: Calculate(expression)",
                            "- Calculator: disabled.",
                        )
                    if not solve_open:
                        prompt_text = prompt_text.replace(
                            "- Equation solver: enabled. API: Solve(equation)",
                            "- Equation solver: disabled.",
                        )
                    if not search_open:
                        prompt_text = prompt_text.replace(
                            "- Web search: enabled. API: Search(query)",
                            "- Web search: disabled.",
                        )
            else:
                prompt_text = text
            inputs += prompt_text
            if args.image is None:
                if n_turn == 0:
                    ids = tokenizer.encode(inputs, return_tensors="pt")
                else:
                    ids = tokenizer.encode(
                        inputs, return_tensors="pt", add_special_tokens=False
                    )

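                # Plugin flow: let the model emit a <|Commands|> block, run it
                # through plugins_api, append the tool output, then resume
                # generation conditioned on that output.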
                if args.with_plugins is not None:
                    generate_output = llm.generate(
                        inputs=ids.to(get_device()),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria,
                    ).cpu()
                    generate_output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]) :]
                    )
                    if streamer is None:
                        end = "" if generate_output_text[-1] == "\n" else "\n"
                        print(generate_output_text, end=end)
                    pattern = r"<\|Commands\|>:(.*?)<eoc>"
                    command_text = ", ".join(re.findall(pattern, generate_output_text))
                    extent_text = plugins_api(
                        command_text,
                        calculate_open=calculate_open,
                        solve_open=solve_open,
                        search_open=search_open,
                    )
                    end = "" if extent_text[-1] == "\n" else "\n"
                    print(extent_text, end=end)
                    extent_text_ids = tokenizer.encode(
                        extent_text, return_tensors="pt", add_special_tokens=False
                    )
                    new_ids = torch.cat((generate_output, extent_text_ids), dim=1)

                    generate_output = llm.generate(
                        inputs=new_ids.to(get_device()),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria,
                    )
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(new_ids[0]) :]
                        )
                        end = "" if output_text[-1] == "\n" else "\n"
                        print(output_text, end=end)
                else:
                    generate_output = llm.generate(
                        inputs=ids.to(get_device()),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria,
                    )
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(ids[0]) :]
                        )
                        end = "" if output_text[-1] == "\n" else "\n"
                        print(output_text, end=end)
                inputs = tokenizer.decode(generate_output[0])
            else:
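                # Multimodal path: split the running prompt on the image
                # token and splice IMAGE_TOKEN_INDEX between the text chunks.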
                chunk_encode = []
                for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                    if idx == 0 and n_turn == 0:
                        cur_encode = tokenizer.encode(chunk)
                    else:
                        cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
                    chunk_encode.append(cur_encode)
                assert len(chunk_encode) == 2
                ids = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    ids.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        ids.append(IMAGE_TOKEN_INDEX)
                ids = torch.tensor(ids).to(get_device()).unsqueeze(0)
                mm_inputs = prepare_inputs_labels_for_multimodal(
                    llm=llm, input_ids=ids, pixel_values=pixel_values
                )

                generate_output = llm.generate(
                    **mm_inputs,
                    generation_config=gen_config,
                    streamer=streamer,
                    bos_token_id=tokenizer.bos_token_id,
                    stopping_criteria=stop_criteria,
                )
                if streamer is None:
                    output_text = tokenizer.decode(generate_output[0])
                    end = "" if output_text[-1] == "\n" else "\n"
                    print(output_text, end=end)
                inputs += tokenizer.decode(generate_output[0])
            n_turn += 1
            inputs += sep
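            # Heuristic: once the full token stream reaches max_new_tokens,
            # clear the history to leave room for future generations.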
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    "Remove the memory of history responses, since "
                    f"it exceeds the length limitation {args.max_new_tokens}."
                )
                n_turn = 0
                inputs = ""


if __name__ == "__main__":
    main()
